{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7cc982d1",
   "metadata": {},
   "source": [
    "# Re-map FLAVA checkpoint\n",
    "\n",
    "Modifying FLAVA's components can cause existing model checkpoints to go out of sync with the updated architecture. This notebook shows how to load the existing checkpoint, re-map the old layers to the new layers, and save the new checkpoint.\n",
    "\n",
    "Uploading the new checkpoint requires access to the PyTorch AWS S3 account; the upload itself is done manually from a local copy."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "411e4191",
   "metadata": {},
   "source": [
    "### Load original model\n",
    "\n",
    "Load the existing checkpoint into the FLAVA class to inspect the current architecture."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "88ee917b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torchmultimodal.models.flava.model import flava_model_for_classification, flava_model_for_pretraining\n",
    "\n",
    "# flava_classification = flava_model_for_classification(num_classes=3)\n",
    "flava_pretraining = flava_model_for_pretraining(pretrained_model_key='flava_full')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f00b369",
   "metadata": {},
   "source": [
    "### Print summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cc286394",
   "metadata": {},
   "outputs": [],
   "source": [
    "flava_pretraining"
   ]
  },
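  {
   "cell_type": "markdown",
   "id": "0d774455",
   "metadata": {},
   "source": [
    "### Mapping function\n",
    "\n",
    "Replace the body of this function with the logic needed to map the old parameter names to the new ones. The example below renames keys containing `attention.attention` so that the duplicated `attention.` prefix is dropped; all other keys are copied unchanged."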
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cc9e4537",
   "metadata": {},
   "outputs": [],
   "source": [
    "import re\n",
    "\n",
    "def map_state_dict(state_dict):\n",
    "    \"\"\"Return a new state dict with old parameter names mapped to new ones.\"\"\"\n",
    "    mapped_state_dict = {}\n",
    "    for param, val in state_dict.items():\n",
    "        # Collapse the first 'attention.attention' into 'attention'.\n",
    "        # The dot is escaped so it matches a literal '.', not any character.\n",
    "        # Keys without a match pass through unchanged.\n",
    "        new_param = re.sub(r'attention\\.attention', 'attention', param, count=1)\n",
    "        mapped_state_dict[new_param] = val\n",
    "    return mapped_state_dict"
   ]
  },
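  {
   "cell_type": "markdown",
   "id": "e4b7a210",
   "metadata": {},
   "source": [
    "As a quick sanity check, the mapping can be tried on a small dictionary before running it on the real checkpoint. This is only a sketch: the key names below are hypothetical and chosen to exercise the `attention.attention` pattern, so actual FLAVA parameter names may differ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9c2d5f61",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical keys for illustration only; the integer values stand in for tensors\n",
    "toy_state_dict = {\n",
    "    'encoder.layer.0.attention.attention.query.weight': 0,\n",
    "    'encoder.layer.0.output.dense.weight': 1,\n",
    "}\n",
    "\n",
    "# Display the remapped keys; only the first key contains the old pattern\n",
    "map_state_dict(toy_state_dict)"
   ]
  },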
  {
   "cell_type": "markdown",
   "id": "29870590",
   "metadata": {},
   "source": [
    "### Load old state dict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "41f64d26",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load from url, replace this path if it changes\n",
    "# old_model_url = 'https://download.pytorch.org/models/multimodal/flava/flava_model.pt'\n",
    "# old_state_dict = torch.hub.load_state_dict_from_url(old_model_url)\n",
    "\n",
    "# Or get from loaded model\n",
    "old_state_dict = flava_pretraining.model.state_dict()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75322113",
   "metadata": {},
   "source": [
    "### Perform re-mapping"
   ]
  },
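  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "17363ae8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uncomment to apply the mapping function defined above\n",
    "# new_state_dict = map_state_dict(old_state_dict)\n",
    "\n",
    "# As written, the state dict is passed through unchanged\n",
    "new_state_dict = old_state_dict"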
   ]
  },
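  {
   "cell_type": "markdown",
   "id": "3fa8d07b",
   "metadata": {},
   "source": [
    "One way to check the result before saving, sketched here under the assumption that `flava_pretraining.model` was instantiated with the updated architecture, is to load the new state dict with `strict=False` and inspect the keys PyTorch reports as missing or unexpected. Both lists should be empty if every parameter name lines up. Note that a shape mismatch still raises an error even with `strict=False`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b19ce54",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Assumption: flava_pretraining.model reflects the new architecture\n",
    "result = flava_pretraining.model.load_state_dict(new_state_dict, strict=False)\n",
    "\n",
    "# Keys the model expects but the state dict lacks\n",
    "print('Missing keys:', result.missing_keys)\n",
    "# Keys in the state dict that the model does not define\n",
    "print('Unexpected keys:', result.unexpected_keys)"
   ]
  },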
  {
   "cell_type": "markdown",
   "id": "d94c4133",
   "metadata": {},
   "source": [
    "### Save updated checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "bc6baad9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local output path; replace with a location on your own machine\n",
    "save_path = '/Users/rafiayub/flava_model.pt'\n",
    "torch.save(new_state_dict, save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
